{
  "nbformat": 4,
  "nbformat_minor": 0,
  "metadata": {
    "colab": {
      "provenance": []
    },
    "kernelspec": {
      "name": "python3",
      "display_name": "Python 3"
    },
    "language_info": {
      "name": "python"
    }
  },
  "cells": [
    {
      "cell_type": "markdown",
      "source": [
        "# Tutorial: Exporting StableHLO from JAX\n",
        "\n",
        "[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)][jax-tutorial-colab]\n",
        "[![Open in Kaggle](https://kaggle.com/static/images/open-in-kaggle.svg)][jax-tutorial-kaggle]\n",
        "\n",
        "## Tutorial Setup\n",
        "\n",
        "### Install required dependencies\n",
        "\n",
        "We'll be using `jax` and `jaxlib` (JAX's XLA package), along with `flax` and `transformers` for some models to export.\n",
        "\n",
        "[jax-tutorial-colab]: https://colab.research.google.com/github/openxla/stablehlo/blob/main/docs/tutorials/jax-export.ipynb\n",
        "[jax-tutorial-kaggle]: https://kaggle.com/kernels/welcome?src=https://github.com/openxla/stablehlo/blob/main/docs/tutorials/jax-export.ipynb"
      ],
      "metadata": {
        "id": "_TuAgGNKt5HO"
      }
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "ENUcO6aML-Nq",
        "collapsed": true
      },
      "outputs": [],
      "source": [
        "!pip install -U jax jaxlib flax transformers"
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "#@title Define `get_stablehlo_asm` to help with MLIR printing\n",
        "from jax._src.interpreters import mlir as jax_mlir\n",
        "from jax._src.lib.mlir import ir\n",
        "\n",
        "# Returns prettyprint of StableHLO module without large constants\n",
        "def get_stablehlo_asm(module_str):\n",
        "  with jax_mlir.make_ir_context():\n",
        "    stablehlo_module = ir.Module.parse(module_str, context=jax_mlir.make_ir_context())\n",
        "    return stablehlo_module.operation.get_asm(large_elements_limit=20)\n",
        "\n",
        "# Disable logging for better tutorial rendering\n",
        "import logging\n",
        "logging.disable(logging.WARNING)"
      ],
      "metadata": {
        "id": "HqjeC_QSugYj",
        "cellView": "form"
      },
      "execution_count": 2,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "source": [
        "_Note: This helper uses a JAX internal API that may break at any time, but it serves no functional purpose in the tutorial aside from readability._"
      ],
      "metadata": {
        "id": "LNlFj80UwX0D"
      }
    },
    {
      "cell_type": "markdown",
      "source": [
        "## Export JAX model to StableHLO using `jax.export`\n",
        "\n",
        "In this section we'll export a very basic JAX function to StableHLO.\n",
        "\n",
        "The preferred API for export is `jax.experimental.export`, which uses `jax.jit` under the hood. As a rule-of-thumb, a function must be `jit`-able to be exported to StableHLO."
      ],
      "metadata": {
        "id": "TEfsW69IBp_V"
      }
    },
    {
      "cell_type": "markdown",
      "source": [
        "### Export basic JAX model to StableHLO\n",
        "\n",
        "Let's start by exporting a basic `plus` function to StableHLO, using `np.int32` argument types to trace the function.\n",
        "\n",
        "Export requires specifying shapes using `jax.ShapeDtypeStruct`, which are trivial to construct from numpy values."
      ],
      "metadata": {
        "id": "3MzMjjf2iIk2"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "import jax\n",
        "from jax.experimental import export\n",
        "import jax.numpy as jnp\n",
        "import numpy as np\n",
        "\n",
        "def plus(x,y):\n",
        "  return jnp.add(x,y)\n",
        "\n",
        "# Create abstract input shapes:\n",
        "inputs = (np.int32(1), np.int32(1),)\n",
        "input_shapes = [jax.ShapeDtypeStruct(input.shape, input.dtype) for input in inputs]\n",
        "stablehlo_add = export.export(plus)(*input_shapes).mlir_module()\n",
        "print(get_stablehlo_asm(stablehlo_add))"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "v-GN3vPbvFoa",
        "outputId": "b128c08f-3591-4142-fd67-561812bb3d4e"
      },
      "execution_count": 3,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "module @jit_plus attributes {jax.uses_shape_polymorphism = false, mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} {\n",
            "  func.func public @main(%arg0: tensor<i32> {mhlo.layout_mode = \"default\"}, %arg1: tensor<i32> {mhlo.layout_mode = \"default\"}) -> (tensor<i32> {jax.result_info = \"\", mhlo.layout_mode = \"default\"}) {\n",
            "    %0 = stablehlo.add %arg0, %arg1 : tensor<i32>\n",
            "    return %0 : tensor<i32>\n",
            "  }\n",
            "}\n",
            "\n"
          ]
        }
      ]
    },
    {
      "cell_type": "markdown",
      "source": [
        "### Export Huggingface FlaxResNet18 to StableHLO\n",
        "\n",
        "Now let's look at a simple model that appears in the wild, `resnet18`.\n",
        "\n",
        "This example will export a `flax` model from the huggingface `transformers` ResNet page, [FlaxResNetModel](https://huggingface.co/docs/transformers/en/model_doc/resnet#transformers.FlaxResNetModel). Much of this steps setup was copied from the huggingface documentation.\n",
        "\n",
        "The documentation also states: _\"Finally, this model supports inherent JAX features such as: **Just-In-Time (JIT) compilation** ...\"_ which means it is perfect for export.\n",
        "\n",
        "Similar to our very basic example, our steps for export are:\n",
        "1. Instantiate a callable (model/function) that can be JIT'ed.\n",
        "2. Specify shapes for export using `jax.ShapeDtypeStruct` on numpy values\n",
        "3. Use the JAX `export` API to get a StableHLO module"
      ],
      "metadata": {
        "id": "xZWVAHQzBsEM"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "from transformers import AutoImageProcessor, FlaxResNetModel\n",
        "import jax\n",
        "import numpy as np\n",
        "\n",
        "# Construct flax model with sample inputs\n",
        "\n",
        "resnet18 = FlaxResNetModel.from_pretrained(\"microsoft/resnet-18\", return_dict=False)\n",
        "sample_input = np.random.randn(1, 3, 224, 224)\n",
        "input_shape = jax.ShapeDtypeStruct(sample_input.shape, sample_input.dtype)\n",
        "\n",
        "# Export to StableHLO\n",
        "stablehlo_resnet18_export = export.export(resnet18)(input_shape)\n",
        "resnet18_stablehlo = get_stablehlo_asm(stablehlo_resnet18_export.mlir_module())\n",
        "print(resnet18_stablehlo[:600], \"\\n...\\n\", resnet18_stablehlo[-345:])"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "53T7jO-v_6PC",
        "outputId": "0536386d-09d2-49a3-b51c-951c20f6e49b"
      },
      "execution_count": 5,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "module @jit__unnamed_wrapped_function_ attributes {jax.uses_shape_polymorphism = false, mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} {\n",
            "  func.func public @main(%arg0: tensor<1x3x224x224xf32> {mhlo.layout_mode = \"default\"}) -> (tensor<1x512x7x7xf32> {mhlo.layout_mode = \"default\"}, tensor<1x512x1x1xf32> {mhlo.layout_mode = \"default\"}) {\n",
            "    %0 = stablehlo.constant dense_resource<__elided__> : tensor<7x7x3x64xf32>\n",
            "    %1 = stablehlo.constant dense_resource<__elided__> : tensor<64xf32>\n",
            "    %2 = stablehlo.constant dense_resource<__elided__> : tensor<64xf32>\n",
            "    %3 = stablehlo.constan \n",
            "...\n",
            " }\n",
            "  func.func private @relu_3(%arg0: tensor<1x7x7x512xf32>) -> tensor<1x7x7x512xf32> {\n",
            "    %0 = stablehlo.constant dense<0.000000e+00> : tensor<f32>\n",
            "    %1 = stablehlo.broadcast_in_dim %0, dims = [] : (tensor<f32>) -> tensor<1x7x7x512xf32>\n",
            "    %2 = stablehlo.maximum %arg0, %1 : tensor<1x7x7x512xf32>\n",
            "    return %2 : tensor<1x7x7x512xf32>\n",
            "  }\n",
            "}\n",
            "\n"
          ]
        }
      ]
    },
    {
      "cell_type": "markdown",
      "source": [
        "### Export with dynamic batch size\n",
        "\n",
        "Not let's export that same model with a dynamic batch size!\n",
        "\n",
        "In the first example we used an input shape of `tensor<1x3x224x224xf32>`, specifying strict constraints on the input shape. If we wanted to defer the concerete shapes used in compilation until a later point, we can specify a `symbolic_shape`, in this case we'll export using `tensor<?x3x224x224xf32>`.\n",
        "\n",
        "Symbolic shapes are specified using `export.symbolic_shape`, with letters representing symint dimensions. For example, a valid 2-d matrix multiplication could use symbolic constraints of: `2,a * a,5` to ensure the refined program will have valid shapes. Symbolic integer names are kept track of by an `export.SymbolicScope` to avoid unintentional name clashes."
      ],
      "metadata": {
        "id": "X2MC9F7Zlx6E"
      }
    },
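    {
      "cell_type": "markdown",
      "source": [
        "As a minimal sketch of the matrix-multiplication case described above, the cell below exports a small `matmul` function (a hypothetical helper, not part of the original tutorial) with the symbolic shapes `2,a` and `a,5`. It reuses the `export` module imported earlier in this notebook and assumes that `in_avals` and `out_avals` are available on the exported object in your JAX version. Both shapes are created in one `SymbolicScope`, so the two `a` dimensions refer to the same symbol."
      ],
      "metadata": {}
    },
    {
      "cell_type": "code",
      "source": [
        "import jax\n",
        "import jax.numpy as jnp\n",
        "import numpy as np\n",
        "\n",
        "# Hypothetical helper used only for this illustration.\n",
        "def matmul(x, y):\n",
        "  return jnp.matmul(x, y)\n",
        "\n",
        "# Both shapes are created in the same scope, so `a` is the same symbol in each.\n",
        "matmul_scope = export.SymbolicScope()\n",
        "lhs_shape = jax.ShapeDtypeStruct(export.symbolic_shape(\"2,a\", scope=matmul_scope), np.float32)\n",
        "rhs_shape = jax.ShapeDtypeStruct(export.symbolic_shape(\"a,5\", scope=matmul_scope), np.float32)\n",
        "\n",
        "# Export with the symbolic shapes and inspect the abstract input/output types.\n",
        "matmul_export = export.export(jax.jit(matmul))(lhs_shape, rhs_shape)\n",
        "print(\"Inputs: \", matmul_export.in_avals)\n",
        "print(\"Outputs:\", matmul_export.out_avals)"
      ],
      "metadata": {},
      "execution_count": null,
      "outputs": []
    },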
    {
      "cell_type": "code",
      "source": [
        "dyn_scope = export.SymbolicScope()\n",
        "dyn_input_shape = jax.ShapeDtypeStruct(export.symbolic_shape(\"a,3,224,224\", scope=dyn_scope), np.float32)\n",
        "dyn_resnet18_export = export.export(resnet18)(dyn_input_shape)\n",
        "dyn_resnet18_stablehlo = get_stablehlo_asm(dyn_resnet18_export.mlir_module())\n",
        "print(dyn_resnet18_stablehlo[:1900], \"\\n...\\n\", dyn_resnet18_stablehlo[-1000:])"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "sIkbtViEMJ3T",
        "outputId": "165019c7-e771-43d6-c324-6bb4222798ec"
      },
      "execution_count": 6,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "module @jit__unnamed_wrapped_function_ attributes {jax.uses_shape_polymorphism = true, mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} {\n",
            "  func.func public @main(%arg0: tensor<?x3x224x224xf32> {mhlo.layout_mode = \"default\"}) -> (tensor<?x512x7x7xf32> {mhlo.layout_mode = \"default\"}, tensor<?x512x1x1xf32> {mhlo.layout_mode = \"default\"}) {\n",
            "    %0 = stablehlo.get_dimension_size %arg0, dim = 0 : (tensor<?x3x224x224xf32>) -> tensor<i32>\n",
            "    %1 = stablehlo.constant dense<1> : tensor<i32>\n",
            "    %2 = stablehlo.compare  GE, %0, %1,  SIGNED : (tensor<i32>, tensor<i32>) -> tensor<i1>\n",
            "    stablehlo.custom_call @shape_assertion(%2, %0) {api_version = 2 : i32, error_message = \"Input shapes do not match the polymorphic shapes specification. Expected value >= 1 for dimension variable 'a'. Using the following polymorphic shapes specifications: args[0].shape = (a, 3, 224, 224). Obtained dimension variables: 'a' = {0} from specification 'a' for dimension args[0].shape[0] (= {0}), . Please see https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#shape-assertion-errors for more details.\", has_side_effect = true} : (tensor<i1>, tensor<i32>) -> ()\n",
            "    %3:2 = call @_wrapped_jax_export_main(%0, %arg0) : (tensor<i32>, tensor<?x3x224x224xf32>) -> (tensor<?x512x7x7xf32>, tensor<?x512x1x1xf32>)\n",
            "    return %3#0, %3#1 : tensor<?x512x7x7xf32>, tensor<?x512x1x1xf32>\n",
            "  }\n",
            "  func.func private @_wrapped_jax_export_main(%arg0: tensor<i32> {jax.global_constant = \"a\"}, %arg1: tensor<?x3x224x224xf32> {mhlo.layout_mode = \"default\"}) -> (tensor<?x512x7x7xf32> {mhlo.layout_mode = \"default\"}, tensor<?x512x1x1xf32> {mhlo.layout_mode = \"default\"}) {\n",
            "    %0 = stablehlo.constant dense_resource<__elided__> : tensor<7x7x3x64xf32>\n",
            "    %1 = stablehlo.constant dense_resource<__elided__> : tensor<64xf32>\n",
            "    %2 = stablehlo.constant dense_resource<__elided__> : tensor<64xf32>\n",
            "    %3 = stablehl \n",
            "...\n",
            " <?x14x14x256xf32>\n",
            "    return %10 : tensor<?x14x14x256xf32>\n",
            "  }\n",
            "  func.func private @relu_3(%arg0: tensor<i32> {jax.global_constant = \"a\"}, %arg1: tensor<?x7x7x512xf32>) -> tensor<?x7x7x512xf32> {\n",
            "    %0 = stablehlo.constant dense<0.000000e+00> : tensor<f32>\n",
            "    %1 = stablehlo.constant dense<7> : tensor<i32>\n",
            "    %2 = stablehlo.constant dense<7> : tensor<i32>\n",
            "    %3 = stablehlo.constant dense<512> : tensor<i32>\n",
            "    %4 = stablehlo.reshape %arg0 : (tensor<i32>) -> tensor<1xi32>\n",
            "    %5 = stablehlo.constant dense<7> : tensor<1xi32>\n",
            "    %6 = stablehlo.constant dense<7> : tensor<1xi32>\n",
            "    %7 = stablehlo.constant dense<512> : tensor<1xi32>\n",
            "    %8 = stablehlo.concatenate %4, %5, %6, %7, dim = 0 : (tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<4xi32>\n",
            "    %9 = stablehlo.dynamic_broadcast_in_dim %0, %8, dims = [] : (tensor<f32>, tensor<4xi32>) -> tensor<?x7x7x512xf32>\n",
            "    %10 = stablehlo.maximum %arg1, %9 : tensor<?x7x7x512xf32>\n",
            "    return %10 : tensor<?x7x7x512xf32>\n",
            "  }\n",
            "}\n",
            "\n"
          ]
        }
      ]
    },
    {
      "cell_type": "markdown",
      "source": [
        "A few things to note in the exported StableHLO:\n",
        "\n",
        "1. The exported program now has `tensor<?x3x224x224xf32>`. These input types can be refined in many ways: SavedModel execution takes care of refinement which we'll see in the next example, but StableHLO also has APIs to [refine shapes](https://github.com/openxla/stablehlo/blob/541db997e449dcfee8536043dfdd49bb13f9ed1a/stablehlo/transforms/Passes.td#L69-L99) and [canonicalize dynamic programs](https://github.com/openxla/stablehlo/blob/541db997e449dcfee8536043dfdd49bb13f9ed1a/stablehlo/transforms/Passes.td#L18-L28) to static programs.\n",
        "2. JAX will generate guards to ensure the values of `a` are valid, in this case `a > 1` is checked. These can be washed away at compile time once refined."
      ],
      "metadata": {
        "id": "dRIu7xlSoDUK"
      }
    },
    {
      "cell_type": "markdown",
      "source": [
        "## Export to SavedModel\n",
        "\n",
        "It is common to want to export a StableHLO model to SavedModel for interop with existing compilation pipelines, existing TF tooling, or serving via [TF Serving](https://github.com/tensorflow/serving).\n",
        "\n",
        "JAX makes it easy to pack StableHLO into a SavedModel, and load that SavedModel in the future. For this section we'll be using our dynamic model from the previous section."
      ],
      "metadata": {
        "id": "xFU5M6Xm1U8_"
      }
    },
    {
      "cell_type": "markdown",
      "source": [
        "### Install latest TF\n",
        "\n",
        "SavedModel definition lives in TF, so we need to install the dependency. We recommend using `tensorflow-cpu` or `tf-nightly`."
      ],
      "metadata": {
        "id": "Gt_bIJaSpsYf"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "!pip install tensorflow-cpu"
      ],
      "metadata": {
        "id": "KZEd7NavBoem",
        "collapsed": true
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "source": [
        "### Export to SavedModel using `jax2tf`\n",
        "\n",
        "JAX provides a simple API for exporting StableHLO into a format that can be packaged in SavedModel in `jax.experimental.jax2tf`. This uses the `export` function under the hood, so the same `jit` requirements apply.\n",
        "\n",
        "Full details on `jax2tf` can be found in the [README](https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#jax-and-tensorflow-interoperation-jax2tfcall_tf), for this example we'll only need to know the `polymorphic_shapes` option to specify our dynamic batch dimension."
      ],
      "metadata": {
        "id": "lf7Fnsrop7BD"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "from jax.experimental import jax2tf\n",
        "import tensorflow as tf\n",
        "\n",
        "exported_f = jax2tf.convert(resnet18, polymorphic_shapes=[\"(a,3,224,224)\"])\n",
        "\n",
        "# Copied from the jax2tf README.md > Usage: saved model\n",
        "my_model = tf.Module()\n",
        "my_model.f = tf.function(exported_f, autograph=False).get_concrete_function(tf.TensorSpec([None, 3, 224, 224], tf.float32))\n",
        "tf.saved_model.save(my_model, '/tmp/resnet18_tf', options=tf.saved_model.SaveOptions(experimental_custom_gradients=True))\n",
        "\n",
        "!ls /tmp/resnet18_tf"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "GXkgtX7QEiWa",
        "outputId": "9034836e-4ba1-4210-c2c1-316777d4ad89",
        "collapsed": true
      },
      "execution_count": 7,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "assets\tfingerprint.pb\tsaved_model.pb\tvariables\n"
          ]
        }
      ]
    },
    {
      "cell_type": "markdown",
      "source": [
        "### Reload and call the SavedModel\n",
        "\n",
        "Now we can load that SavedModel and compile using our `sample_input` from a previous example.\n",
        "\n",
        "_Note: the restored model does *not* require JAX to run, just XLA._"
      ],
      "metadata": {
        "id": "CmKABmLdrS3C"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "restored_model = tf.saved_model.load('/tmp/resnet18_tf')\n",
        "restored_result = restored_model.f(tf.constant(sample_input, tf.float32))\n",
        "print(\"Result shape:\", restored_result[0].shape)"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "9Az3dXXWrVDM",
        "outputId": "cac6d3b1-126e-4e66-dc4d-f2a798ede463"
      },
      "execution_count": 8,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "Result shape: (1, 512, 7, 7)\n"
          ]
        }
      ]
    },
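    {
      "cell_type": "markdown",
      "source": [
        "Because the model was exported with a symbolic batch dimension, the restored function should also accept inputs with other batch sizes. The sketch below assumes the cells above have been run (it reuses `restored_model`) and feeds a batch of 4 random images; the batch size is an arbitrary choice for illustration."
      ],
      "metadata": {}
    },
    {
      "cell_type": "code",
      "source": [
        "import numpy as np\n",
        "import tensorflow as tf\n",
        "\n",
        "# Arbitrary batch size chosen for illustration; any value >= 1 should satisfy\n",
        "# the shape assertion on the symbolic dimension `a`.\n",
        "batch_input = np.random.randn(4, 3, 224, 224).astype(np.float32)\n",
        "batch_result = restored_model.f(tf.constant(batch_input))\n",
        "print(\"Result shape:\", batch_result[0].shape)"
      ],
      "metadata": {},
      "execution_count": null,
      "outputs": []
    },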
    {
      "cell_type": "markdown",
      "source": [
        "## Common Troubleshooting\n",
        "\n",
        "If the function can be JIT'ed, then it can be exported. Focus on making `jax.jit` work first, or look in desired project for uses of JIT already (ex: [AlphaFold's `apply`](https://github.com/google-deepmind/alphafold/blob/dbe2a438ebfc6289f960292f15dbf421a05e563d/alphafold/model/model.py#L89) can be exported easily).\n",
        "\n",
        "See [JAX JIT Examples](https://jax.readthedocs.io/en/latest/_autosummary/jax.jit.html#Examples) for troubleshooting. The most common issue is control flow, which can often be resolved with `static_argnums` / `static_argnames` as in the linked example.\n",
        "\n",
        "For opening a ticket for help, include a repo using one of the above APIs, this will help get the issue resolved much quicker!"
      ],
      "metadata": {
        "id": "a2Dsm2oF5jn4"
      }
    }
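    ,
    {
      "cell_type": "markdown",
      "source": [
        "To illustrate the control-flow issue mentioned above, here is a minimal sketch using a hypothetical `scale` function whose Python `if` depends on an argument. Lowering it with `jax.jit` fails while `negate` is traced, and succeeds once `negate` is marked static with `static_argnums`. If `jax.jit(...).lower(...)` succeeds for your function, export will typically work as well."
      ],
      "metadata": {}
    },
    {
      "cell_type": "code",
      "source": [
        "import jax\n",
        "import jax.numpy as jnp\n",
        "\n",
        "# Hypothetical function used only for this illustration.\n",
        "def scale(x, negate):\n",
        "  # A Python `if` needs a concrete value, so `negate` cannot be a traced argument.\n",
        "  if negate:\n",
        "    return -x\n",
        "  return x\n",
        "\n",
        "spec = jax.ShapeDtypeStruct((4,), jnp.float32)\n",
        "\n",
        "try:\n",
        "  jax.jit(scale).lower(spec, True)\n",
        "except TypeError as e:\n",
        "  # JAX's tracer conversion errors derive from TypeError.\n",
        "  print(\"Tracing failed:\", type(e).__name__)\n",
        "\n",
        "# Marking `negate` as static makes its value available at trace time.\n",
        "lowered = jax.jit(scale, static_argnums=1).lower(spec, True)\n",
        "print(lowered.as_text())"
      ],
      "metadata": {},
      "execution_count": null,
      "outputs": []
    }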
  ]
}
